import json

import jieba
import pickle
from __00__config import Config
import warnings

warnings.filterwarnings('ignore')

config = Config()
def predict_func(data):
	"""
	本函数用于后续前端编写中的客户端
	Args:
		data: 用户输入的数据

	Returns: 模型预测的标签

	"""
	# 读取本地模型
	with open(config.rf_model_save_path, 'rb') as f:
		model = pickle.load(f)
	# 读取本地词向量器
	with open(config.tfidf_model_save_path, 'rb') as f:
		tfidf = pickle.load(f)
	word = ' '.join(jieba.lcut(data['questions'])[0: 75])
	# 向量化
	features = tfidf.transform([word])
	# 预测
	y_predict_list = model.predict(features)
	# print(y_predict_list)
	y_predict_index = y_predict_list[0]
	with open(config.class_doc_path, 'r', encoding='utf-8') as j:
		class_doc = json.load(j)
	index2label = class_doc['idx_to_label']
	# 检查预测标签
	# print(index2label)
	# print(index2label[str(y_predict_index)])
	data["pred_class"] = index2label[str(y_predict_index)]
	return data


if __name__ == '__main__':
	res = predict_func({'questions': '我最近总是有还多的白带不知道是为什么请问女性白带增多的原因都有哪些呢？'})
	print(res)
